[feat] JoyAI-JoyImage-Edit support #13444
Conversation
yiyixuxu
left a comment
thanks for the PR! I left some initial feedback
    return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
ohh what's going on here? is this some legacy code? can we remove?
We first developed JoyImage and then trained JoyImage-Edit on top of it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit inherits from it. We will also open-source JoyImage in the future.
They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Suggested change: replace the fused-QKV attention block above with a single call into a refactored attention module, e.g.

attn_output, text_attn_output = self.attn(...)
can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in the attention calculation here into a JoyImageAttention (similar to FluxAttention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)
also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75)
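A rough sketch of the shape such a refactor could take, modeled on the Flux pattern — class names, signatures, and the simplified attention math (rotary embeddings omitted, plain SDPA instead of the repo's attention helper) are all assumptions, not the final implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    # all of the attention math lives here, analogous to FluxAttnProcessor
    # (rotary embeddings and attention masks are omitted for brevity)
    def __call__(self, attn: "JoyImageAttention", img: torch.Tensor, txt: torch.Tensor):
        def qkv(proj, x, q_norm, k_norm):
            b, l, _ = x.shape
            q, k, v = proj(x).view(b, l, 3, attn.heads_num, attn.head_dim).unbind(2)
            return q_norm(q).to(v), k_norm(k).to(v), v

        img_q, img_k, img_v = qkv(attn.img_attn_qkv, img, attn.img_attn_q_norm, attn.img_attn_k_norm)
        txt_q, txt_k, txt_v = qkv(attn.txt_attn_qkv, txt, attn.txt_attn_q_norm, attn.txt_attn_k_norm)

        # joint attention over the concatenated image + text tokens, then split back
        q = torch.cat((img_q, txt_q), dim=1).transpose(1, 2)  # B H L D
        k = torch.cat((img_k, txt_k), dim=1).transpose(1, 2)
        v = torch.cat((img_v, txt_v), dim=1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2, 3)
        return out[:, : img.shape[1]], out[:, img.shape[1] :]


class JoyImageAttention(nn.Module):
    # holds all the attention weights (fused QKV for both streams + QK norms), analogous to FluxAttention
    def __init__(self, dim: int, heads_num: int, eps: float = 1e-6):
        super().__init__()
        self.heads_num = heads_num
        self.head_dim = dim // heads_num
        self.img_attn_qkv = nn.Linear(dim, dim * 3, bias=True)
        self.txt_attn_qkv = nn.Linear(dim, dim * 3, bias=True)
        self.img_attn_q_norm = nn.RMSNorm(self.head_dim, eps=eps)
        self.img_attn_k_norm = nn.RMSNorm(self.head_dim, eps=eps)
        self.txt_attn_q_norm = nn.RMSNorm(self.head_dim, eps=eps)
        self.txt_attn_k_norm = nn.RMSNorm(self.head_dim, eps=eps)
        self.processor = JoyImageAttnProcessor()

    def forward(self, img: torch.Tensor, txt: torch.Tensor):
        return self.processor(self, img, txt)


# toy usage
attn = JoyImageAttention(dim=64, heads_num=4)
img_out, txt_out = attn(torch.randn(1, 16, 64), torch.randn(1, 8, 64))
```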
Thanks for the reminder. I'll clean up this messy code.
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Suggested change: remove the ModulateX class.
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Suggested change: remove the ModulateDiT class.
Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX.
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Suggested change:
- self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
+ self.img_mod = JoyImageModulate(...)
let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too
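For reference, a minimal sketch of what such a layer could look like, based on the ModulateDiT definition quoted above (the JoyImageModulate name and exact signature are assumptions):

```python
import torch
import torch.nn as nn


class JoyImageModulate(nn.Module):
    # SiLU -> zero-initialized Linear producing `factor` modulation chunks (shift/scale/gate),
    # equivalent to the ModulateDiT module shown above
    def __init__(self, hidden_size: int, factor: int):
        super().__init__()
        self.factor = factor
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)


# hypothetical usage inside the block's __init__
# self.img_mod = JoyImageModulate(hidden_size, factor=6)
```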
Ok, I will refactor modulation and use ModulateWan
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)
I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
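For illustration, a minimal toy example of the config pattern being suggested — the module and its arguments here are made up; only the @register_to_config / self.config mechanism is the real diffusers API:

```python
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_size: int = 32, is_repa: bool = False, repa_layer: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # arguments captured by @register_to_config are exposed as self.config.<name>,
        # so a separate SimpleNamespace holding them is unnecessary
        if self.config.is_repa:
            pass  # repa-specific branch would go here
        return self.proj(x)


model = ToyTransformer()
print(model.config.is_repa, model.config.repa_layer)  # False 8
```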
Was the repa logic removed because it is not used in inference?
timesteps: List[int] = None,
sigmas: List[float] = None,
Suggested change:
- timesteps: List[int] = None,
- sigmas: List[float] = None,
+ timesteps: list[int] | None = None,
+ sigmas: list[float] | None = None,
nit: could we switch to Python 3.9+ style implicit type hints here and elsewhere?
dg845
left a comment
Thanks for the PR! Left an initial design review :).
@yiyixuxu @dg845 Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since the logic is quite complex, I directly copied it over from the training code. If you have any further suggestions, please feel free to share. Thank you so much!
# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
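For context, the kind of custom RMSNorm being described (explicit FP32 upcast for the mean-square computation) typically looks something like the sketch below; this illustrates the general pattern, not the exact module from the original codebase:

```python
import torch
import torch.nn as nn


class FP32RMSNorm(nn.Module):
    # upcasts to float32 for the variance/rsqrt computation, then casts back to the input dtype
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x_normed.to(x.dtype)
```

If the original behaviour matters numerically (e.g. in bf16), keeping an upcasting variant or verifying parity against torch.nn.RMSNorm would settle the question.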
@bot /style

Style fix is beginning... View the workflow run here.
if negative_prompt is None and negative_prompt_embeds is None:
    if num_items <= 1:
        negative_prompt = ["<|im_start|>user\n<|im_end|>\n"] * batch_size
    else:
nit: my understanding is that wrapping the edit instructions (e.g. "Add wings to the astronaut.") with the Qwen3-VL template is important for sample quality, as seen in #13444 (comment). So I think it would be more user-friendly to automatically wrap the prompt with the template inside the pipeline like we do for negative_prompt here.
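As a sketch of that idea — the exact template string below is an assumption inferred from the default negative prompt in the diff above, not the one referenced in the linked comment:

```python
def _maybe_apply_edit_template(prompt: str) -> str:
    # hypothetical helper: wrap a bare edit instruction in a Qwen-style chat template,
    # mirroring how the default negative_prompt above is already templated
    if prompt.startswith("<|im_start|>"):
        return prompt  # caller already supplied a templated prompt
    return f"<|im_start|>user\n{prompt}<|im_end|>\n"


print(_maybe_apply_edit_template("Add wings to the astronaut."))
```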
    return prompt_embeds, prompt_embeds_mask


def encode_prompt(
suggestion: I think the way encode_prompt is currently implemented is confusing, as the code splits into two paths (the main encode_prompt logic and encode_prompt_multiple_images) which partially do the same thing. I think it would be clearer to refactor encode_prompt to something like this:
def encode_prompt(...) -> tuple[torch.Tensor, torch.Tensor]:
    # 1. Handle inputs
    device = device or self._execution_device
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
    has_image_conditions = images is not None

    # 2. Generate prompt embeddings if necessary using Qwen3VL tokenizer/processor
    if prompt_embeds is None:
        template_type = "multiple_images" if has_image_conditions else "image"
        # _get_qwen_prompt_embeds is responsible for:
        # 1. Creating the final templated prompt
        # 2. Running the processor (or possibly tokenizer) to get the text encoder inputs
        # 3. Running the text encoder and getting the right Qwen3-VL hidden_states
        # 4. Any post-processing that's specific to the multiple or single image case
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, template_type, device)

    # 3. Post-process prompt_embeds (common logic for both cases)
    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    # Handle expanding to num_images_per_prompt in both cases
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    ...
    return prompt_embeds, prompt_embeds_mask

I think ideally _get_qwen_prompt_embeds would handle both the "multiple_images" and "image" cases, but if they can't be combined cleanly we could have separate helpers for each case.
Thanks for the suggestion. The current implementation comes from our training code, and we have a multi-image editing model under active development that relies on this structure. We’d prefer to keep the current approach for now to avoid disrupting that work.
    super().tearDown()


def get_dummy_components(self):
    tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration"
I think the current tiny Qwen3-VL testing checkpoint huangfeice/tiny-random-Qwen3VLForConditionalGeneration has the following issues:
- It's quite big (~5M params, ~19 MB), which makes the pipeline tests quite heavy, so I think we should try to reduce the size of this checkpoint. It looks like most of the parameters are in the input and output embeddings (e.g. embed_tokens), so for example reducing the vocab_size should help.
- The checkpoint might be misconfigured: the model config defines a vision patch_size of 14, but the processor config defines an image_processor patch_size of 16. I think this mismatch is causing some tests such as test_cfg to fail.
The tiny-random-Qwen3VLForConditionalGeneration checkpoint has been updated on the Hugging Face Hub. Config fixed in 9d9ef52.
dg845
left a comment
Thanks for the refactor! I think the PR is close to merge. I left a few small comments; the most important one is #13444 (comment), as this causes some pipeline tests to fail. Also, can you run make style and make quality to fix the code style, and make fix-copies to fix any dummy objects or out-of-sync copies?
CC @yiyixuxu to take a look at the forward hook in JoyImageEditPipeline._get_last_decoder_hidden_states for transformers>=5 compatibility.
@dg845 @yiyixuxu
Looking forward to the merge! Wishing you a Happy International Workers’ Day! Here are some scripts:

Inference Scripts

import torch
from diffusers import JoyImageEditPipeline
from diffusers.utils import load_image

pipeline = JoyImageEditPipeline.from_pretrained(
    "jdopensource/JoyAI-Image-Edit-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")

img_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image = load_image(img_path)

# edit
image = pipeline(
    image=image,
    prompt="Add wings to the astronaut.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]
image.save("edit.png")

# t2i
image = pipeline(
    prompt="A toy astronaut with wings.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]
image.save("t2i.png")

# batch edit
output = pipeline(
    image=[image, image],
    prompt=["Add wings to the astronaut.", "Add halo to the astronaut."],
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
)
output.images[0].save("result1.png")
output.images[1].save("result2.png")
@bot /style

Style fix is beginning... View the workflow run here.
@bot /style

Style fix is beginning... View the workflow run here.
When I run `pytest tests/pipelines/joyimage/test_joyimage_edit.py`, the following tests fail:

After debugging, my understanding is that the root cause of these test failures is that block-level offloading works for the Qwen3-VL text encoder, but leaf-level offloading does not. For now, I think we can either skip the tests or override them so that the Qwen3-VL text encoder is moved to the target device directly.

Claude suggestion for overriding test_group_offloading_inference:

def test_group_offloading_inference(self):
    # Qwen3VLForConditionalGeneration (the text encoder) is incompatible with leaf_level group
    # offloading. Its Qwen3VLVisionModel.fast_pos_embed_interpolate reads
    # `self.pos_embed.weight.device` to create intermediate tensors before the Embedding's
    # pre_forward hook fires, so the intermediate tensors land on CPU while hidden_states
    # (produced by the Conv3d patch_embed) land on CUDA, causing a device mismatch.
    #
    # block_level works correctly: since Qwen3VLForConditionalGeneration has no ModuleList as a
    # direct child, the entire model forms one unmatched group that onloads atomically before any
    # submodule code runs, so pos_embed.weight.device is CUDA by the time it is read.
    #
    # For leaf_level we therefore move the text encoder to the target device directly (the same
    # pattern the base test already uses for the VAE) and only apply leaf_level offloading to
    # the diffusers-native transformer.
    if not self.test_group_offloading:
        return

    def create_pipe():
        torch.manual_seed(0)
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def run_forward(pipe):
        torch.manual_seed(0)
        inputs = self.get_dummy_inputs(torch_device)
        return pipe(**inputs)[0]

    pipe = create_pipe().to(torch_device)
    output_without_group_offloading = run_forward(pipe)

    # block_level: the full text encoder becomes one group (no direct ModuleList children), so
    # the atomic onload/offload is safe.
    pipe = create_pipe()
    for component_name in ["transformer", "text_encoder"]:
        component = getattr(pipe, component_name, None)
        if component is None:
            continue
        if hasattr(component, "enable_group_offload"):
            component.enable_group_offload(
                torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1
            )
        else:
            apply_group_offloading(
                component,
                onload_device=torch.device(torch_device),
                offload_type="block_level",
                num_blocks_per_group=1,
            )
    pipe.vae.to(torch_device)
    output_with_block_level = run_forward(pipe)

    # leaf_level: skip the text encoder (transformers model with device-dependent tensor
    # creation) and move it to the target device directly.
    pipe = create_pipe()
    pipe.transformer.enable_group_offload(
        torch.device(torch_device), offload_type="leaf_level"
    )
    pipe.text_encoder.to(torch_device)
    pipe.vae.to(torch_device)
    output_with_leaf_level = run_forward(pipe)

    if torch.is_tensor(output_without_group_offloading):
        output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
        output_with_block_level = output_with_block_level.detach().cpu().numpy()
        output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy()

    self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4))
    self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4))

In the longer term (a separate PR), we can consider adding support for specifying modules that only support block-level offloading in e.g.
# Internal helpers
# ------------------------------------------------------------------

def _get_last_decoder_hidden_states(self, forward_fn, **kwargs):
ohhhh
we will double-check with the transformers team
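For background, the generic forward-hook pattern for capturing a submodule's output looks like the sketch below; the helper name and arguments are placeholders for illustration, not the pipeline's actual _get_last_decoder_hidden_states implementation:

```python
import torch
import torch.nn as nn


def get_last_hidden_states(decoder_layer: nn.Module, forward_fn, *args, **kwargs):
    # Register a forward hook on the last decoder layer, run the forward pass,
    # then remove the hook. `decoder_layer` and `forward_fn` are placeholders.
    captured = {}

    def _hook(module, inputs, output):
        captured["hidden_states"] = output[0] if isinstance(output, tuple) else output

    handle = decoder_layer.register_forward_hook(_hook)
    try:
        forward_fn(*args, **kwargs)
    finally:
        handle.remove()
    return captured["hidden_states"]


# toy usage with a stand-in "decoder layer"
layer = nn.Linear(4, 4)
model = nn.Sequential(nn.Linear(4, 4), layer)
out = get_last_hidden_states(layer, model, torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```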
Description
We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.
GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430
Model Overview
JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).
Key Features
Image edit examples